{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tensorflow/python/compat/v2_compat.py:96: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "non-resource variables are not supported in the long term\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow.compat.v1 as tf\n",
    "\n",
    "from open_spiel.python import rl_environment\n",
    "from open_spiel.python.pytorch import dqn as dqn_pt\n",
    "from open_spiel.python.algorithms import dqn\n",
    "from open_spiel.python.algorithms import random_agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_against_random_bots(env, trained_agents, random_agents, num_episodes):\n",
    "  \"\"\"Scores each trained agent, seat by seat, against the random agents.\n",
    "\n",
    "  For every seat, the line-up is a copy of `random_agents` in which that\n",
    "  seat's entry is replaced by the matching trained agent. `num_episodes`\n",
    "  episodes are played per seat, stepping agents with `is_evaluation=True`.\n",
    "\n",
    "  Args:\n",
    "    env: environment providing `reset()`, `step(actions)`, `is_turn_based`.\n",
    "    trained_agents: agents under evaluation, indexed by seat.\n",
    "    random_agents: baseline agents, indexed by seat; sliced, not mutated.\n",
    "    num_episodes: number of episodes played for each seat.\n",
    "\n",
    "  Returns:\n",
    "    Array with one entry per seat: the evaluated agent's rewards summed over\n",
    "    all of that seat's episodes, divided by `num_episodes`.\n",
    "  \"\"\"\n",
    "  totals = np.zeros(len(trained_agents))\n",
    "  for seat, candidate in enumerate(trained_agents):\n",
    "    lineup = random_agents[:]\n",
    "    lineup[seat] = candidate\n",
    "    for _ in range(num_episodes):\n",
    "      step_result = env.reset()\n",
    "      episode_return = 0\n",
    "      while not step_result.last():\n",
    "        mover = step_result.observations[\"current_player\"]\n",
    "        if env.is_turn_based:\n",
    "          # Only the player whose turn it is picks an action.\n",
    "          chosen = lineup[mover].step(step_result, is_evaluation=True)\n",
    "          actions = [chosen.action]\n",
    "        else:\n",
    "          # Every agent in the line-up is stepped on the same time step.\n",
    "          outputs = [\n",
    "              member.step(step_result, is_evaluation=True) for member in lineup\n",
    "          ]\n",
    "          actions = [out.action for out in outputs]\n",
    "        step_result = env.step(actions)\n",
    "        episode_return += step_result.rewards[seat]\n",
    "      totals[seat] += episode_return\n",
    "  return totals / num_episodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pt_main(game,\n",
    "            config,\n",
    "            checkpoint_dir,\n",
    "            num_train_episodes,\n",
    "            eval_every,\n",
    "            hidden_layers_sizes,\n",
    "            replay_buffer_capacity,\n",
    "            batch_size,\n",
    "            num_eval_episodes=1000):\n",
    "  \"\"\"Trains one PyTorch DQN agent per seat by playing them against each other.\n",
    "\n",
    "  Args:\n",
    "    game: game name passed to `rl_environment.Environment`.\n",
    "    config: dict of game parameters, forwarded as keyword arguments.\n",
    "    checkpoint_dir: unused here; kept so existing callers keep working.\n",
    "    num_train_episodes: number of training episodes to play.\n",
    "    eval_every: run an evaluation each time `ep + 1` is a multiple of this.\n",
    "    hidden_layers_sizes: hidden layer sizes; each entry is cast to int.\n",
    "    replay_buffer_capacity: forwarded to `dqn_pt.DQN`.\n",
    "    batch_size: forwarded to `dqn_pt.DQN`.\n",
    "    num_eval_episodes: episodes played per seat in each evaluation.\n",
    "\n",
    "  Returns:\n",
    "    List with one entry per evaluation: the per-seat mean reward array\n",
    "    returned by `eval_against_random_bots`.\n",
    "  \"\"\"\n",
    "  env = rl_environment.Environment(game, **config)\n",
    "  # Take the seat count from the environment instead of assuming two players.\n",
    "  num_players = env.num_players\n",
    "  info_state_size = env.observation_spec()[\"info_state\"][0]\n",
    "  num_actions = env.action_spec()[\"num_actions\"]\n",
    "\n",
    "  # random agents for evaluation\n",
    "  random_agents = [\n",
    "      random_agent.RandomAgent(player_id=idx, num_actions=num_actions)\n",
    "      for idx in range(num_players)\n",
    "  ]\n",
    "\n",
    "  hidden_layers_sizes = [int(l) for l in hidden_layers_sizes]\n",
    "  # pylint: disable=g-complex-comprehension\n",
    "  agents = [\n",
    "      dqn_pt.DQN(\n",
    "          player_id=idx,\n",
    "          state_representation_size=info_state_size,\n",
    "          num_actions=num_actions,\n",
    "          hidden_layers_sizes=hidden_layers_sizes,\n",
    "          replay_buffer_capacity=replay_buffer_capacity,\n",
    "          batch_size=batch_size) for idx in range(num_players)\n",
    "  ]\n",
    "  result = []\n",
    "  for ep in range(num_train_episodes):\n",
    "    # The evaluation runs before training episode `ep + 1` is played.\n",
    "    if (ep + 1) % eval_every == 0:\n",
    "      r_mean = eval_against_random_bots(env, agents, random_agents,\n",
    "                                        num_eval_episodes)\n",
    "      result.append(r_mean)\n",
    "      print(\"[%s] Mean episode rewards %s\" %(ep + 1, r_mean))\n",
    "\n",
    "    time_step = env.reset()\n",
    "    while not time_step.last():\n",
    "      player_id = time_step.observations[\"current_player\"]\n",
    "      if env.is_turn_based:\n",
    "        agent_output = agents[player_id].step(time_step)\n",
    "        action_list = [agent_output.action]\n",
    "      else:\n",
    "        agents_output = [agent.step(time_step) for agent in agents]\n",
    "        action_list = [agent_output.action for agent_output in agents_output]\n",
    "      time_step = env.step(action_list)\n",
    "\n",
    "    # Episode is over, step all agents with final info state.\n",
    "    for agent in agents:\n",
    "      agent.step(time_step)\n",
    "  return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tf_main(game,\n",
    "            config,\n",
    "            checkpoint_dir,\n",
    "            num_train_episodes,\n",
    "            eval_every,\n",
    "            hidden_layers_sizes,\n",
    "            replay_buffer_capacity,\n",
    "            batch_size,\n",
    "            num_eval_episodes=1000):\n",
    "  \"\"\"Trains one TensorFlow DQN agent per seat by playing them against each other.\n",
    "\n",
    "  Args:\n",
    "    game: game name passed to `rl_environment.Environment`.\n",
    "    config: dict of game parameters, forwarded as keyword arguments.\n",
    "    checkpoint_dir: unused here; kept so existing callers keep working.\n",
    "    num_train_episodes: number of training episodes to play.\n",
    "    eval_every: run an evaluation each time `ep + 1` is a multiple of this.\n",
    "    hidden_layers_sizes: hidden layer sizes; each entry is cast to int.\n",
    "    replay_buffer_capacity: forwarded to `dqn.DQN`.\n",
    "    batch_size: forwarded to `dqn.DQN`.\n",
    "    num_eval_episodes: episodes played per seat in each evaluation.\n",
    "\n",
    "  Returns:\n",
    "    List with one entry per evaluation: the per-seat mean reward array\n",
    "    returned by `eval_against_random_bots`.\n",
    "  \"\"\"\n",
    "  env = rl_environment.Environment(game, **config)\n",
    "  # Take the seat count from the environment instead of assuming two players.\n",
    "  num_players = env.num_players\n",
    "  info_state_size = env.observation_spec()[\"info_state\"][0]\n",
    "  num_actions = env.action_spec()[\"num_actions\"]\n",
    "\n",
    "  # random agents for evaluation\n",
    "  random_agents = [\n",
    "      random_agent.RandomAgent(player_id=idx, num_actions=num_actions)\n",
    "      for idx in range(num_players)\n",
    "  ]\n",
    "\n",
    "  # Build the agents in a fresh graph: reusing the default graph would make a\n",
    "  # second call in the same kernel add to, and re-initialise, the variables\n",
    "  # left behind by the previous call.\n",
    "  with tf.Graph().as_default(), tf.Session() as sess:\n",
    "    hidden_layers_sizes = [int(l) for l in hidden_layers_sizes]\n",
    "    # pylint: disable=g-complex-comprehension\n",
    "    agents = [\n",
    "        dqn.DQN(\n",
    "            session=sess,\n",
    "            player_id=idx,\n",
    "            state_representation_size=info_state_size,\n",
    "            num_actions=num_actions,\n",
    "            hidden_layers_sizes=hidden_layers_sizes,\n",
    "            replay_buffer_capacity=replay_buffer_capacity,\n",
    "            batch_size=batch_size) for idx in range(num_players)\n",
    "    ]\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    result_tf = []\n",
    "    for ep in range(num_train_episodes):\n",
    "      # The evaluation runs before training episode `ep + 1` is played.\n",
    "      if (ep + 1) % eval_every == 0:\n",
    "        r_mean = eval_against_random_bots(env, agents, random_agents,\n",
    "                                          num_eval_episodes)\n",
    "        result_tf.append(r_mean)\n",
    "        print(\"[%s] Mean episode rewards %s\" %(ep + 1, r_mean))\n",
    "\n",
    "      time_step = env.reset()\n",
    "      while not time_step.last():\n",
    "        player_id = time_step.observations[\"current_player\"]\n",
    "        if env.is_turn_based:\n",
    "          agent_output = agents[player_id].step(time_step)\n",
    "          action_list = [agent_output.action]\n",
    "        else:\n",
    "          agents_output = [agent.step(time_step) for agent in agents]\n",
    "          action_list = [agent_output.action for agent_output in agents_output]\n",
    "        time_step = env.step(action_list)\n",
    "\n",
    "      # Episode is over, step all agents with final info state.\n",
    "      for agent in agents:\n",
    "        agent.step(time_step)\n",
    "  return result_tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training schedule passed to both the PyTorch and the TensorFlow run.\n",
    "num_train_episodes, eval_every = 10_000, 100\n",
    "checkpoint_dir = \"/tmp/dqn_test\"\n",
    "\n",
    "# DQN hyper-parameters.\n",
    "hidden_layers_sizes = [64] * 2\n",
    "replay_buffer_capacity = 100_000\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BREAKTHROUGH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Breakthrough played on a 5x5 board.\n",
    "game = \"breakthrough\"\n",
    "config = dict(columns=5, rows=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[100] Mean episode rewards [0.396 0.44 ]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py:130: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /pytorch/c10/cuda/CUDAFunctions.cpp:100.)\n",
      "  Variable._execution_engine.run_backward(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[200] Mean episode rewards [0.396 0.472]\n",
      "[300] Mean episode rewards [0.522 0.49 ]\n",
      "[400] Mean episode rewards [0.664 0.572]\n",
      "[500] Mean episode rewards [0.676 0.578]\n",
      "[600] Mean episode rewards [0.686 0.578]\n",
      "[700] Mean episode rewards [0.72  0.484]\n",
      "[800] Mean episode rewards [0.772 0.432]\n",
      "[900] Mean episode rewards [0.664 0.384]\n",
      "[1000] Mean episode rewards [0.616 0.332]\n",
      "[1100] Mean episode rewards [0.646 0.338]\n",
      "[1200] Mean episode rewards [0.586 0.348]\n",
      "[1300] Mean episode rewards [0.512 0.328]\n",
      "[1400] Mean episode rewards [0.474 0.324]\n",
      "[1500] Mean episode rewards [0.416 0.316]\n",
      "[1600] Mean episode rewards [0.548 0.266]\n",
      "[1700] Mean episode rewards [0.64  0.174]\n",
      "[1800] Mean episode rewards [0.622 0.226]\n",
      "[1900] Mean episode rewards [0.524 0.3  ]\n",
      "[2000] Mean episode rewards [0.568 0.132]\n",
      "[2100] Mean episode rewards [0.544 0.14 ]\n",
      "[2200] Mean episode rewards [0.698 0.06 ]\n",
      "[2300] Mean episode rewards [0.696 0.188]\n",
      "[2400] Mean episode rewards [0.726 0.346]\n",
      "[2500] Mean episode rewards [0.792 0.404]\n",
      "[2600] Mean episode rewards [0.876 0.512]\n",
      "[2700] Mean episode rewards [0.902 0.464]\n",
      "[2800] Mean episode rewards [0.802 0.444]\n",
      "[2900] Mean episode rewards [0.866 0.684]\n",
      "[3000] Mean episode rewards [0.884 0.654]\n",
      "[3100] Mean episode rewards [0.822 0.626]\n",
      "[3200] Mean episode rewards [0.836 0.716]\n",
      "[3300] Mean episode rewards [0.76  0.466]\n",
      "[3400] Mean episode rewards [0.662 0.708]\n",
      "[3500] Mean episode rewards [0.752 0.782]\n",
      "[3600] Mean episode rewards [0.648 0.662]\n",
      "[3700] Mean episode rewards [0.832 0.754]\n",
      "[3800] Mean episode rewards [0.794 0.792]\n",
      "[3900] Mean episode rewards [0.732 0.724]\n",
      "[4000] Mean episode rewards [0.882 0.648]\n",
      "[4100] Mean episode rewards [0.828 0.566]\n",
      "[4200] Mean episode rewards [0.904 0.654]\n",
      "[4300] Mean episode rewards [0.882 0.434]\n",
      "[4400] Mean episode rewards [0.886 0.636]\n",
      "[4500] Mean episode rewards [0.914 0.728]\n",
      "[4600] Mean episode rewards [0.954 0.712]\n",
      "[4700] Mean episode rewards [0.926 0.656]\n",
      "[4800] Mean episode rewards [0.888 0.78 ]\n",
      "[4900] Mean episode rewards [0.93 0.77]\n",
      "[5000] Mean episode rewards [0.95  0.764]\n",
      "[5100] Mean episode rewards [0.944 0.848]\n",
      "[5200] Mean episode rewards [0.978 0.642]\n",
      "[5300] Mean episode rewards [0.928 0.948]\n",
      "[5400] Mean episode rewards [0.952 0.804]\n",
      "[5500] Mean episode rewards [0.976 0.928]\n",
      "[5600] Mean episode rewards [0.98  0.916]\n",
      "[5700] Mean episode rewards [0.952 0.924]\n",
      "[5800] Mean episode rewards [0.962 0.94 ]\n",
      "[5900] Mean episode rewards [0.946 0.948]\n",
      "[6000] Mean episode rewards [0.958 0.914]\n",
      "[6100] Mean episode rewards [0.936 0.962]\n",
      "[6200] Mean episode rewards [0.95  0.962]\n",
      "[6300] Mean episode rewards [0.972 0.962]\n",
      "[6400] Mean episode rewards [0.91  0.952]\n",
      "[6500] Mean episode rewards [0.956 0.956]\n",
      "[6600] Mean episode rewards [0.976 0.932]\n",
      "[6700] Mean episode rewards [0.968 0.948]\n",
      "[6800] Mean episode rewards [0.98  0.946]\n",
      "[6900] Mean episode rewards [0.976 0.952]\n",
      "[7000] Mean episode rewards [0.982 0.95 ]\n",
      "[7100] Mean episode rewards [0.988 0.956]\n",
      "[7200] Mean episode rewards [0.984 0.948]\n",
      "[7300] Mean episode rewards [0.968 0.96 ]\n",
      "[7400] Mean episode rewards [0.978 0.97 ]\n",
      "[7500] Mean episode rewards [0.96  0.942]\n",
      "[7600] Mean episode rewards [0.966 0.968]\n",
      "[7700] Mean episode rewards [0.956 0.948]\n",
      "[7800] Mean episode rewards [0.976 0.962]\n",
      "[7900] Mean episode rewards [0.958 0.964]\n",
      "[8000] Mean episode rewards [0.966 0.942]\n",
      "[8100] Mean episode rewards [0.934 0.948]\n",
      "[8200] Mean episode rewards [0.95  0.952]\n",
      "[8300] Mean episode rewards [0.946 0.958]\n",
      "[8400] Mean episode rewards [0.974 0.94 ]\n",
      "[8500] Mean episode rewards [0.94  0.934]\n",
      "[8600] Mean episode rewards [0.958 0.952]\n",
      "[8700] Mean episode rewards [0.93  0.966]\n",
      "[8800] Mean episode rewards [0.968 0.94 ]\n",
      "[8900] Mean episode rewards [0.962 0.942]\n",
      "[9000] Mean episode rewards [0.946 0.95 ]\n",
      "[9100] Mean episode rewards [0.968 0.938]\n",
      "[9200] Mean episode rewards [0.962 0.95 ]\n",
      "[9300] Mean episode rewards [0.976 0.94 ]\n",
      "[9400] Mean episode rewards [0.98  0.948]\n",
      "[9500] Mean episode rewards [0.964 0.934]\n",
      "[9600] Mean episode rewards [0.97  0.922]\n",
      "[9700] Mean episode rewards [0.972 0.936]\n",
      "[9800] Mean episode rewards [0.966 0.932]\n",
      "[9900] Mean episode rewards [0.974 0.94 ]\n",
      "[10000] Mean episode rewards [0.966 0.928]\n"
     ]
    }
   ],
   "source": [
    "# Train the PyTorch agents; keep the evaluation history for plotting.\n",
    "pt_result = pt_main(\n",
    "    game=game,\n",
    "    config=config,\n",
    "    checkpoint_dir=checkpoint_dir,\n",
    "    num_train_episodes=num_train_episodes,\n",
    "    eval_every=eval_every,\n",
    "    hidden_layers_sizes=hidden_layers_sizes,\n",
    "    replay_buffer_capacity=replay_buffer_capacity,\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# x-axis: the training-episode count at which each evaluation was run\n",
    "# (entry i of the history was recorded at episode (i + 1) * eval_every).\n",
    "ep = [(i + 1) * eval_every for i in range(len(pt_result))]\n",
    "pt_r_mean0 = [y[0] for y in pt_result]\n",
    "pt_r_mean1 = [y[1] for y in pt_result]\n",
    "\n",
    "plt.plot(ep, pt_r_mean0, c='red', label='player 0')\n",
    "plt.plot(ep, pt_r_mean1, c='blue', label='player 1')\n",
    "plt.title('PyTorch DQN vs. random agents (%s)' % game)\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('Mean episode rewards')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[100] Mean episode rewards [0.668 0.166]\n",
      "[200] Mean episode rewards [0.582 0.026]\n",
      "[300] Mean episode rewards [ 0.674 -0.1  ]\n",
      "[400] Mean episode rewards [ 0.556 -0.128]\n",
      "[500] Mean episode rewards [ 0.446 -0.142]\n",
      "[600] Mean episode rewards [ 0.322 -0.142]\n",
      "[700] Mean episode rewards [ 0.23  -0.214]\n",
      "[800] Mean episode rewards [ 0.292 -0.236]\n",
      "[900] Mean episode rewards [ 0.28  -0.298]\n",
      "[1000] Mean episode rewards [ 0.224 -0.298]\n",
      "[1100] Mean episode rewards [ 0.238 -0.306]\n",
      "[1200] Mean episode rewards [ 0.224 -0.37 ]\n",
      "[1300] Mean episode rewards [ 0.18  -0.404]\n",
      "[1400] Mean episode rewards [ 0.24  -0.388]\n",
      "[1500] Mean episode rewards [ 0.212 -0.36 ]\n",
      "[1600] Mean episode rewards [ 0.29  -0.302]\n",
      "[1700] Mean episode rewards [ 0.354 -0.37 ]\n",
      "[1800] Mean episode rewards [ 0.524 -0.33 ]\n",
      "[1900] Mean episode rewards [ 0.546 -0.242]\n",
      "[2000] Mean episode rewards [ 0.65  -0.162]\n",
      "[2100] Mean episode rewards [ 0.574 -0.158]\n",
      "[2200] Mean episode rewards [ 0.512 -0.068]\n",
      "[2300] Mean episode rewards [ 0.488 -0.204]\n",
      "[2400] Mean episode rewards [ 0.382 -0.004]\n",
      "[2500] Mean episode rewards [ 0.164 -0.006]\n",
      "[2600] Mean episode rewards [0.244 0.076]\n",
      "[2700] Mean episode rewards [0.332 0.21 ]\n",
      "[2800] Mean episode rewards [0.246 0.26 ]\n",
      "[2900] Mean episode rewards [0.23  0.572]\n",
      "[3000] Mean episode rewards [0.47  0.424]\n",
      "[3100] Mean episode rewards [0.524 0.35 ]\n",
      "[3200] Mean episode rewards [0.566 0.448]\n",
      "[3300] Mean episode rewards [0.494 0.396]\n",
      "[3400] Mean episode rewards [0.644 0.412]\n",
      "[3500] Mean episode rewards [0.666 0.606]\n",
      "[3600] Mean episode rewards [0.618 0.528]\n",
      "[3700] Mean episode rewards [0.676 0.734]\n",
      "[3800] Mean episode rewards [0.682 0.668]\n",
      "[3900] Mean episode rewards [0.794 0.784]\n",
      "[4000] Mean episode rewards [0.86 0.68]\n",
      "[4100] Mean episode rewards [0.768 0.82 ]\n",
      "[4200] Mean episode rewards [0.854 0.754]\n",
      "[4300] Mean episode rewards [0.912 0.768]\n",
      "[4400] Mean episode rewards [0.946 0.832]\n",
      "[4500] Mean episode rewards [0.918 0.744]\n",
      "[4600] Mean episode rewards [0.934 0.778]\n",
      "[4700] Mean episode rewards [0.95  0.792]\n",
      "[4800] Mean episode rewards [0.88 0.79]\n",
      "[4900] Mean episode rewards [0.956 0.83 ]\n",
      "[5000] Mean episode rewards [0.934 0.838]\n",
      "[5100] Mean episode rewards [0.948 0.882]\n",
      "[5200] Mean episode rewards [0.936 0.828]\n",
      "[5300] Mean episode rewards [0.906 0.848]\n",
      "[5400] Mean episode rewards [0.942 0.836]\n",
      "[5500] Mean episode rewards [0.94  0.876]\n",
      "[5600] Mean episode rewards [0.944 0.866]\n",
      "[5700] Mean episode rewards [0.954 0.868]\n",
      "[5800] Mean episode rewards [0.954 0.856]\n",
      "[5900] Mean episode rewards [0.95 0.86]\n",
      "[6000] Mean episode rewards [0.956 0.826]\n",
      "[6100] Mean episode rewards [0.938 0.888]\n",
      "[6200] Mean episode rewards [0.964 0.892]\n",
      "[6300] Mean episode rewards [0.956 0.902]\n",
      "[6400] Mean episode rewards [0.938 0.88 ]\n",
      "[6500] Mean episode rewards [0.972 0.854]\n",
      "[6600] Mean episode rewards [0.942 0.844]\n",
      "[6700] Mean episode rewards [0.936 0.868]\n",
      "[6800] Mean episode rewards [0.952 0.878]\n",
      "[6900] Mean episode rewards [0.944 0.904]\n",
      "[7000] Mean episode rewards [0.96  0.932]\n",
      "[7100] Mean episode rewards [0.954 0.892]\n",
      "[7200] Mean episode rewards [0.948 0.944]\n",
      "[7300] Mean episode rewards [0.968 0.902]\n",
      "[7400] Mean episode rewards [0.936 0.898]\n",
      "[7500] Mean episode rewards [0.966 0.898]\n",
      "[7600] Mean episode rewards [0.954 0.908]\n",
      "[7700] Mean episode rewards [0.974 0.902]\n",
      "[7800] Mean episode rewards [0.966 0.888]\n",
      "[7900] Mean episode rewards [0.966 0.888]\n",
      "[8000] Mean episode rewards [0.954 0.914]\n",
      "[8100] Mean episode rewards [0.962 0.928]\n",
      "[8200] Mean episode rewards [0.954 0.924]\n",
      "[8300] Mean episode rewards [0.92  0.902]\n",
      "[8400] Mean episode rewards [0.932 0.936]\n",
      "[8500] Mean episode rewards [0.964 0.908]\n",
      "[8600] Mean episode rewards [0.956 0.926]\n",
      "[8700] Mean episode rewards [0.932 0.916]\n",
      "[8800] Mean episode rewards [0.938 0.918]\n",
      "[8900] Mean episode rewards [0.926 0.948]\n",
      "[9000] Mean episode rewards [0.912 0.944]\n",
      "[9100] Mean episode rewards [0.926 0.944]\n",
      "[9200] Mean episode rewards [0.916 0.952]\n",
      "[9300] Mean episode rewards [0.926 0.94 ]\n",
      "[9400] Mean episode rewards [0.884 0.906]\n",
      "[9500] Mean episode rewards [0.914 0.922]\n",
      "[9600] Mean episode rewards [0.918 0.922]\n",
      "[9700] Mean episode rewards [0.928 0.936]\n",
      "[9800] Mean episode rewards [0.946 0.934]\n",
      "[9900] Mean episode rewards [0.926 0.924]\n",
      "[10000] Mean episode rewards [0.926 0.918]\n"
     ]
    }
   ],
   "source": [
    "# Train the TensorFlow agents; keep the evaluation history for plotting.\n",
    "result_tf = tf_main(\n",
    "    game=game,\n",
    "    config=config,\n",
    "    checkpoint_dir=checkpoint_dir,\n",
    "    num_train_episodes=num_train_episodes,\n",
    "    eval_every=eval_every,\n",
    "    hidden_layers_sizes=hidden_layers_sizes,\n",
    "    replay_buffer_capacity=replay_buffer_capacity,\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# x-axis: the training-episode count at which each evaluation was run\n",
    "# (entry i of the history was recorded at episode (i + 1) * eval_every).\n",
    "ep = [(i + 1) * eval_every for i in range(len(result_tf))]\n",
    "tf_r_mean0 = [y[0] for y in result_tf]\n",
    "tf_r_mean1 = [y[1] for y in result_tf]\n",
    "\n",
    "plt.plot(ep, tf_r_mean0, c='red', label='player 0')\n",
    "plt.plot(ep, tf_r_mean1, c='blue', label='player 1')\n",
    "plt.title('TensorFlow DQN vs. random agents (%s)' % game)\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('Mean episode rewards')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build each x-axis here instead of reusing `ep` left over from an earlier\n",
    "# cell: entry i of a history was recorded at episode (i + 1) * eval_every.\n",
    "pt_ep = [(i + 1) * eval_every for i in range(len(pt_r_mean0))]\n",
    "tf_ep = [(i + 1) * eval_every for i in range(len(tf_r_mean0))]\n",
    "\n",
    "plt.plot(pt_ep, pt_r_mean0, c='skyblue', label='PyTorch, player 0')\n",
    "plt.plot(pt_ep, pt_r_mean1, c='skyblue', linestyle='dashed',\n",
    "         label='PyTorch, player 1')\n",
    "plt.plot(tf_ep, tf_r_mean0, c='pink', label='TensorFlow, player 0')\n",
    "plt.plot(tf_ep, tf_r_mean1, c='pink', linestyle='dashed',\n",
    "         label='TensorFlow, player 1')\n",
    "plt.title('PyTorch vs. TensorFlow DQN (%s)' % game)\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('Mean episode rewards')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TIC-TAC-TOE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tic-tac-toe is run with no game parameters and its own training schedule.\n",
    "game, config = \"tic_tac_toe\", {}\n",
    "num_train_episodes = 20_000\n",
    "eval_every = 1_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1000] Mean episode rewards [0.823 0.112]\n",
      "[2000] Mean episode rewards [0.769 0.089]\n",
      "[3000] Mean episode rewards [0.883 0.161]\n",
      "[4000] Mean episode rewards [0.723 0.229]\n",
      "[5000] Mean episode rewards [0.424 0.125]\n",
      "[6000] Mean episode rewards [0.54  0.246]\n",
      "[7000] Mean episode rewards [0.637 0.244]\n",
      "[8000] Mean episode rewards [0.794 0.236]\n",
      "[9000] Mean episode rewards [0.643 0.148]\n",
      "[10000] Mean episode rewards [0.813 0.148]\n",
      "[11000] Mean episode rewards [0.626 0.17 ]\n",
      "[12000] Mean episode rewards [0.622 0.188]\n",
      "[13000] Mean episode rewards [0.874 0.244]\n",
      "[14000] Mean episode rewards [0.856 0.183]\n",
      "[15000] Mean episode rewards [0.825 0.28 ]\n",
      "[16000] Mean episode rewards [0.82  0.361]\n",
      "[17000] Mean episode rewards [0.847 0.241]\n",
      "[18000] Mean episode rewards [0.869 0.296]\n",
      "[19000] Mean episode rewards [0.904 0.261]\n",
      "[20000] Mean episode rewards [0.858 0.277]\n"
     ]
    }
   ],
   "source": [
    "# Train the PyTorch agents; keep the evaluation history for plotting.\n",
    "pt_result = pt_main(\n",
    "    game=game,\n",
    "    config=config,\n",
    "    checkpoint_dir=checkpoint_dir,\n",
    "    num_train_episodes=num_train_episodes,\n",
    "    eval_every=eval_every,\n",
    "    hidden_layers_sizes=hidden_layers_sizes,\n",
    "    replay_buffer_capacity=replay_buffer_capacity,\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# x-axis: the training-episode count at which each evaluation was run\n",
    "# (entry i of the history was recorded at episode (i + 1) * eval_every).\n",
    "ep = [(i + 1) * eval_every for i in range(len(pt_result))]\n",
    "pt_r_mean0 = [y[0] for y in pt_result]\n",
    "pt_r_mean1 = [y[1] for y in pt_result]\n",
    "\n",
    "plt.plot(ep, pt_r_mean0, c='red', label='player 0')\n",
    "plt.plot(ep, pt_r_mean1, c='blue', label='player 1')\n",
    "plt.title('PyTorch DQN vs. random agents (%s)' % game)\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('Mean episode rewards')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1000] Mean episode rewards [ 0.493 -0.204]\n",
      "[2000] Mean episode rewards [ 0.346 -0.244]\n",
      "[3000] Mean episode rewards [ 0.537 -0.054]\n",
      "[4000] Mean episode rewards [ 0.464 -0.059]\n",
      "[5000] Mean episode rewards [0.46  0.144]\n",
      "[6000] Mean episode rewards [0.442 0.08 ]\n",
      "[7000] Mean episode rewards [0.606 0.068]\n",
      "[8000] Mean episode rewards [0.447 0.165]\n",
      "[9000] Mean episode rewards [0.702 0.196]\n",
      "[10000] Mean episode rewards [0.694 0.227]\n",
      "[11000] Mean episode rewards [0.757 0.213]\n",
      "[12000] Mean episode rewards [0.829 0.149]\n",
      "[13000] Mean episode rewards [0.733 0.186]\n",
      "[14000] Mean episode rewards [0.8   0.318]\n",
      "[15000] Mean episode rewards [0.849 0.308]\n",
      "[16000] Mean episode rewards [0.789 0.198]\n",
      "[17000] Mean episode rewards [0.825 0.353]\n",
      "[18000] Mean episode rewards [0.804 0.367]\n",
      "[19000] Mean episode rewards [0.827 0.355]\n",
      "[20000] Mean episode rewards [0.796 0.368]\n"
     ]
    }
   ],
   "source": [
    "# Train the TensorFlow agents; keep the evaluation history for plotting.\n",
    "result_tf = tf_main(\n",
    "    game=game,\n",
    "    config=config,\n",
    "    checkpoint_dir=checkpoint_dir,\n",
    "    num_train_episodes=num_train_episodes,\n",
    "    eval_every=eval_every,\n",
    "    hidden_layers_sizes=hidden_layers_sizes,\n",
    "    replay_buffer_capacity=replay_buffer_capacity,\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# x-axis: the training-episode count at which each evaluation was run\n",
    "# (entry i of the history was recorded at episode (i + 1) * eval_every).\n",
    "ep = [(i + 1) * eval_every for i in range(len(result_tf))]\n",
    "tf_r_mean0 = [y[0] for y in result_tf]\n",
    "tf_r_mean1 = [y[1] for y in result_tf]\n",
    "\n",
    "plt.plot(ep, tf_r_mean0, c='red', label='player 0')\n",
    "plt.plot(ep, tf_r_mean1, c='blue', label='player 1')\n",
    "plt.title('TensorFlow DQN vs. random agents (%s)' % game)\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('Mean episode rewards')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build each x-axis here instead of reusing `ep` left over from an earlier\n",
    "# cell: entry i of a history was recorded at episode (i + 1) * eval_every.\n",
    "pt_ep = [(i + 1) * eval_every for i in range(len(pt_r_mean0))]\n",
    "tf_ep = [(i + 1) * eval_every for i in range(len(tf_r_mean0))]\n",
    "\n",
    "plt.plot(pt_ep, pt_r_mean0, c='skyblue', label='PyTorch, player 0')\n",
    "plt.plot(pt_ep, pt_r_mean1, c='skyblue', linestyle='dashed',\n",
    "         label='PyTorch, player 1')\n",
    "plt.plot(tf_ep, tf_r_mean0, c='pink', label='TensorFlow, player 0')\n",
    "plt.plot(tf_ep, tf_r_mean1, c='pink', linestyle='dashed',\n",
    "         label='TensorFlow, player 1')\n",
    "plt.title('PyTorch vs. TensorFlow DQN (%s)' % game)\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('Mean episode rewards')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
